{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "bc6171f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5d8cc42e",
   "metadata": {},
   "source": [
    "# indexing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "febcdf80",
   "metadata": {},
   "outputs": [],
   "source": [
    "a=torch.rand(4,3,28,28) #batch_size, chanel, H, W"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "2f05c572",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([3, 28, 28])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "eb7a7982",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([28, 28])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a[0,0].shape # 第0张图片第0个通道"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "426612ab",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.3266)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a[0,0,2,4] # 第0张图片第0个通道的第二行第四列的像素点"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d56bfe52",
   "metadata": {},
   "source": [
    "# select first/last N"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "b69c351d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 3, 28, 28])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "93e19ddf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 3, 28, 28])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a[:2].shape # 可以把：理解为->，前两张图片，所以dim1=2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "fb35dc11",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 1, 28, 28])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a[:2,:1,:,:].shape # 前两张图片第一个通道的所有图片的数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "b30e291d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 2, 28, 28])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a[:2,1:,:,:].shape # 前两张图片，从第一个通道开始到最末尾（G,B通道，所以dim2=2）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "8579959c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 1, 28, 28])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a[:2,-1:,:,:].shape # -1表示从末尾索引，-1：也就表示从最后一个索引到最后一个，所以dim2也就=1"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4237a735",
   "metadata": {},
   "source": [
    "## selece by steps(对图片隔着采样)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "ebc121b2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 3, 14, 14])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a[:,:,0:28:2,0:28:2].shape # batch_size, chanel, H(隔行采样), W(隔行采样)\n",
    "# start : end : steps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "98877cb9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 3, 14, 14])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a[:,:,::2,::2].shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "63e10718",
   "metadata": {},
   "source": [
    "# select by specific index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "4e329c10",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 3, 28, 28])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "2e757586",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 3, 28, 28])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a.index_select(0,torch.tensor([0,2])).shape # 对第0个维度的0，1张图片进行索引"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "f4696bec",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 2, 28, 28])"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a.index_select(1,torch.tensor([1,2])).shape # 四张图片只取两个通道"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1ac928e4",
   "metadata": {},
   "source": [
    "对于index_select函数，第二个参数不能直接给list，要将之转换为tensor才可以"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "5735ad8b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 3, 28, 28])"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a.index_select(2,torch.arange(28)).shape # 取[0,28)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "093287b7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 3, 8, 28])"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a.index_select(2,torch.arange(8)).shape # 取[0,7)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b5ca351",
   "metadata": {},
   "source": [
    "# ... (这不是语气，表示任意多的维度)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "11a63bee",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 3, 28, 28])"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "c439dc97",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 3, 28, 28])"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a[...].shape # 在这里代表4个："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "191721fd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([3, 28, 28])"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a[0,...].shape # 在这里代表3个："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "808c0990",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 28, 28])"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a[:,1,...].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "523abc12",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 3, 28, 2])"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a[...,:2].shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0d5e730a",
   "metadata": {},
   "source": [
    "# select by mask(不推荐用，因为会把数据打平)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "e5b32767",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.0704, -0.9625,  0.3872, -1.1956],\n",
       "        [-0.6052,  0.5082,  1.0785,  1.0886],\n",
       "        [ 0.4051, -1.0605, -0.4139, -0.8271]])"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = torch.randn(3,4)\n",
    "x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "86541440",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[False, False, False, False],\n",
       "        [False,  True,  True,  True],\n",
       "        [False, False, False, False]])"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mask = x.ge(0.5) # 检查哪个元素大于0.5\n",
    "mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "ad10487d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.5082, 1.0785, 1.0886])"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.masked_select(x, mask)  # 选出大于0.5的元素"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a839597f",
   "metadata": {},
   "source": [
    "### 之所以打平是因为>0.5的元素个数是根据内容才能确定的"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "f3abf28d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([3])"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.masked_select(x, mask).shape # dim为1，长度不定"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb2916a8",
   "metadata": {},
   "source": [
    "# select by flatten index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "27e1adf0",
   "metadata": {},
   "outputs": [],
   "source": [
    "src=torch.tensor([[4,3,5],\n",
    "                 [6,7,8]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "4f007ab3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([4, 5, 8])"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.take(src,torch.tensor([0,2,5])) # take会先将tensor打平，从[2,3]变为[5]，([4,3,5],[6,7,8]) -> ([4,3,5,6,7,8])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7346ceb9",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
